from __future__ import absolute_import

from distutils.version import LooseVersion
import numpy as np
import scipy  # work around a bug in recent PyTorch when scipy is imported after torch

import paddorch
import paddorch as th
import builtins
from paddorch.utils import dlpack
# TODO dlpack

from ... import ndarray as nd
from ... import kernel as K
from ...function.base import TargetCode
from ...base import dgl_warning
import paddle
if LooseVersion(th.__version__) < LooseVersion("1.2.0"):
    dgl_warning("Detected an old version of PyTorch. Suggest using torch>=1.2.0 "
                "for the best experience.")


def data_type_dict():
    return {'float16': paddle.fluid.core.VarDesc.VarType.FP16,
            'float32': paddle.fluid.core.VarDesc.VarType.FP32,
            'float64': paddle.fluid.core.VarDesc.VarType.FP64,
            'uint8': paddle.fluid.core.VarDesc.VarType.UINT8,
            'int8': paddle.fluid.core.VarDesc.VarType.INT8,
            'int16': paddle.fluid.core.VarDesc.VarType.INT16,
            'int32': paddle.fluid.core.VarDesc.VarType.INT32,
            'int64': paddle.fluid.core.VarDesc.VarType.INT64,
            'bool': paddle.fluid.core.VarDesc.VarType.BOOL}


def cpu():
    return paddle.CPUPlace()


def tensor(data, dtype=None):
    if dtype is None:
        if isinstance(data, (list, tuple)):
            # Infer the dtype from the first element: int -> int64,
            # anything else -> float32.
            dtype = "int64" if isinstance(data[0], int) else "float32"
        else:
            dtype = data.dtype
    return th.tensor(data, dtype=dtype)
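# A minimal sketch of the dtype inference above (values are illustrative):
#
#     tensor([1, 2, 3]).dtype     # first element is an int  -> int64
#     tensor([1.0, 2.0]).dtype    # anything else            -> float32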


def as_scalar(data):
    return data.item()


def get_preferred_sparse_format():
    """Get the preferred sparse matrix format supported by the backend.

    Different backends have their own preferred formats. This info is useful
    when constructing a sparse matrix.
    """
    return "coo"


def sparse_matrix(data, index, shape, force_format=False):
    fmt = index[0]
    if fmt != 'coo':
        raise TypeError('Paddle backend only supports COO format. But got %s.' % fmt)
    spmat = th.sparse_coo_tensor(index[1], data, shape)
    return spmat, None
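# Usage sketch (hypothetical data): ``index`` carries the format tag plus the
# 2 x nnz coordinate tensor, which is how DGL hands sparse inputs to a backend.
#
#     data = tensor([1.0, 2.0])
#     coords = tensor([[0, 1], [1, 0]])       # row indices, column indices
#     spmat, shuffle = sparse_matrix(data, ('coo', coords), (2, 2))
#     sparse_matrix_indices(spmat)            # -> ('coo', <indices tensor>)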


def sparse_matrix_indices(spmat):
    return ('coo', spmat._indices())


def is_tensor(obj):
    return isinstance(obj, th.Tensor)


def shape(input):
    return tuple(input.shape)


def dtype(input):
    return input.dtype


def ndim(input):
    return input.dim()

def context(input):
    if isinstance(input,th.Tensor):
        return input.device
    else:
        return paddle.fluid.core.Place()


def device_type(ctx):
    if "cpu" in str(ctx).lower():
        return "cpu"
    else:
        return "cuda"

def device_id(ctx):
    # TODO: find a better way to recover the device index; paddle places do
    # not expose a torch-style ``ctx.index``, so report device 0 for now.
    return 0


def astype(input, ty):
    return input.type(ty)


# TODO to_dense()
def asnumpy(input):
    if isinstance(input, th.sparse.FloatTensor):
        return input.to_dense().numpy()
    else:
        # strip the "paddle." prefix so numpy can parse the dtype name
        return input.numpy().astype(str(input.dtype).replace("paddle.", ""))


def copy_to(input, ctx):
    if 'cpu' in str(ctx).lower():
        return th.convertTensor(input._copy_to(paddle.CPUPlace(), False))
    elif 'cuda' in str(ctx).lower():
        # TODO set_device: the CUDA device index is currently hardcoded to 0.
        return th.convertTensor(input._copy_to(paddle.CUDAPlace(0), False))
    else:
        raise RuntimeError('Invalid context', ctx)


def sum(input, dim, keepdims=False):
    return th.sum(input, dim=dim, keepdim=keepdims)


def reduce_sum(input):
    return input.sum()


def mean(input, dim):
    return th.mean(input, dim=dim)


def reduce_mean(input):
    return input.mean()


def max(input, dim):
    # NOTE: the second argmax array is not returned
    return th.max(input, dim=dim)


def reduce_max(input):
    return input.max()


def min(input, dim):
    # NOTE: the second argmin array is not returned
    return th.min(input, dim=dim)


def reduce_min(input):
    return input.min()


def argsort(input, dim, descending):
    return th.argsort(input, dim=dim, descending=descending)


def topk(input, k, dim, descending=True):
    return th.topk(input, k, dim, largest=descending)[0]


def argtopk(input, k, dim, descending=True):
    return th.topk(input, k, dim, largest=descending)[1]
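# Sketch (illustrative values): topk returns the values, argtopk the indices:
#
#     x = tensor([3., 1., 2.])
#     topk(x, 2, 0)       # -> [3., 2.]
#     argtopk(x, 2, 0)    # -> [0, 2]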


def exp(input):
    return th.exp(input)


def softmax(input, dim=-1):
    return th.softmax(input, dim=dim)


def cat(seq, dim):
    return th.cat(seq, dim=dim)


def stack(seq, dim):
    return th.stack(seq, dim=dim)


def split(input, sizes_or_sections, dim):
    return th.split(input, sizes_or_sections, dim)


def repeat(input, repeats, dim):
    # return th.repeat_interleave(input, repeats, dim) # PyTorch 1.1
    if dim < 0:
        dim += input.dim()
    return th.flatten(th.stack([input] * repeats, dim=dim + 1), dim, dim + 1)
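# Sketch of the trick above (illustrative values): stacking ``repeats`` copies
# along a fresh axis and flattening it back interleaves the elements, matching
# what torch.repeat_interleave would produce:
#
#     x = tensor([[1, 2], [3, 4]])
#     repeat(x, 2, 0)    # -> [[1, 2], [1, 2], [3, 4], [3, 4]]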


def gather_row(data, row_index):
    return th.index_select(data, 0, row_index)


def slice_axis(data, axis, begin, end):
    return th.narrow(data, axis, begin, end - begin)


def take(data, indices, dim):
    new_shape = data.shape[:dim] + indices.shape + data.shape[dim + 1:]
    # collapse a trailing singleton axis so view() receives a 1-D target shape
    if len(new_shape) == 2 and new_shape[1] == 1:
        new_shape = new_shape[0]
    return th.index_select(data, dim, indices.view(-1)).view(new_shape)
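# Sketch (illustrative values): the gathered rows are reshaped so the index
# tensor's shape replaces the indexed axis:
#
#     x = tensor([[1, 2], [3, 4], [5, 6]])
#     idx = tensor([[0, 2], [1, 1]])
#     take(x, idx, 0)    # rows 0, 2, 1, 1 viewed as shape (2, 2, 2)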


def narrow_row(x, start, stop):
    return x[start:stop]


def scatter_row(data, row_index, value):
    return data.index_copy(0, row_index, value)


def scatter_row_inplace(data, row_index, value):
    data[row_index] = value


def squeeze(input, dim):
    return th.squeeze(input, dim)


def unsqueeze(input, dim):
    return th.unsqueeze(input, dim)


def reshape(input, shape):
    return th.reshape(input, shape)


def swapaxes(input, axis1, axis2):
    return th.transpose(input, axis1, axis2)


def zeros(shape, dtype, ctx):
    return th.zeros(shape, dtype=dtype, device=ctx)


def zeros_like(input):
    return th.zeros_like(input)


def ones(shape, dtype, ctx):
    return th.ones(shape, dtype=dtype, device=ctx)


def uniform(shape, dtype, ctx, low, high):
    out = th.empty(shape, dtype=dtype, device=ctx)
    # paddorch's uniform_ samples a fresh tensor for the given shape; assign
    # the values into ``out`` to keep the dtype/device placement above.
    temp = th.uniform_(out.shape, low, high)
    th.assign(temp, out)
    return out


def pad_packed_tensor(input, lengths, value, l_min=None):
    old_shape = input.shape
    if isinstance(lengths, th.Tensor):
        max_len = as_scalar(lengths.max())
    else:
        max_len = builtins.max(lengths)

    if l_min is not None:
        max_len = builtins.max(max_len, l_min)

    batch_size = len(lengths)
    device = input.device
    x = input.new(batch_size * max_len, *old_shape[1:])
    x.fill_(value)
    index = []
    for i, l in enumerate(lengths):
        index.extend(range(i * max_len, i * max_len + l))
    index = th.tensor(index).to(device)
    return scatter_row(x, index, input).view(batch_size, max_len, *old_shape[1:])
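# Sketch with illustrative values: three packed rows split into segments of
# lengths 2 and 1 become a (batch, max_len, *) tensor padded with ``value``:
#
#     packed = tensor([[1., 2.], [3., 4.], [5., 6.]])
#     pad_packed_tensor(packed, [2, 1], 0)
#     # -> [[[1., 2.], [3., 4.]],
#     #     [[5., 6.], [0., 0.]]]        shape (2, 2, 2)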


def pack_padded_tensor(input, lengths):
    batch_size, max_len = input.shape[:2]
    device = input.device
    index = []
    for i, l in enumerate(lengths):
        index.extend(range(i * max_len, i * max_len + l))
    index = th.tensor(index).to(device)
    return gather_row(input.view(batch_size * max_len, -1), index)
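# pack_padded_tensor inverts the padding above (continuing the same
# illustrative values): rows beyond each segment's length are dropped.
#
#     padded = pad_packed_tensor(packed, [2, 1], 0)
#     pack_padded_tensor(padded, [2, 1])
#     # -> [[1., 2.], [3., 4.], [5., 6.]]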


def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
    y = th.zeros(n_segs, *input.shape[1:]).to(input)
    seg_id = seg_id.view((-1,) + (1,) * (input.dim() - 1)).expand_as(input)
    # scatter_add returns a new tensor here; copy the result back into the
    # pre-allocated buffer ``y`` so that buffer is the one returned.
    y2 = y.scatter_add(dim, seg_id, input)
    th.copy(y2, y)
    return y


def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
    w = unsorted_1d_segment_sum(th.ones_like(seg_id), seg_id, n_segs, 0).to(input)
    w = w.clamp(min=1)  # remove 0 entries
    y = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
    y = y / w.view((-1,) + (1,) * (y.dim() - 1))
    return y
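# Sketch with illustrative values: entries sharing a segment id are summed or
# averaged, and segments that receive no entries stay zero:
#
#     x = tensor([1., 2., 3., 4.])
#     seg = tensor([0, 0, 2, 2])
#     unsorted_1d_segment_sum(x, seg, 3, 0)     # -> [3., 0., 7.]
#     unsorted_1d_segment_mean(x, seg, 3, 0)    # -> [1.5, 0., 3.5]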


def boolean_mask(input, mask):
    return input[mask]


def equal(x, y):
    return x == y


def logical_not(input):
    return ~input


def unique(input):
    return th.unique(input)


def full_1d(length, fill_value, dtype, ctx):
    return th.full((length,), fill_value, dtype=dtype, device=ctx)


def nonzero_1d(input):
    x = th.nonzero(input).squeeze()
    return x if x.dim() == 1 else x.view(-1)
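# Sketch (illustrative values): indices of nonzero entries as a 1-D tensor:
#
#     nonzero_1d(tensor([0, 1, 0, 2]))    # -> [1, 3]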


def sort_1d(input):
    # return th.sort(input)
    input = input.view(-1)
    order = th.argsort(input)
    return input[order], order


def arange(start, stop):
    return th.arange(start, stop, dtype=np.int64)


def rand_shuffle(arr):
    idx = th.randperm(len(arr))
    return arr[idx]


def zerocopy_to_dlpack(input):
    return dlpack.to_dlpack(input)


def zerocopy_from_dlpack(dlpack_tensor):
    return dlpack.from_dlpack(dlpack_tensor)


def zerocopy_to_numpy(input):
    # NOTE: not zerocopy
    return asnumpy(input)


def zerocopy_from_numpy(np_array):
    return th.as_tensor(np_array, dtype=np_array.dtype)


def zerocopy_to_dgl_ndarray(input):
    return nd.from_dlpack(dlpack.to_dlpack(input))


def zerocopy_from_dgl_ndarray(input):
    return dlpack.from_dlpack(input.to_dlpack())


class BinaryReduce(th.autograd.Function):

    def forward(self, reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, out_data,
                out_size, lhs_map, rhs_map, out_map):

        # Track which inputs need gradients (positions 5 and 6 are lhs_data
        # and rhs_data); hooks stand in for torch's ctx.needs_input_grad.
        self.needs_input_grad = [False] * 7
        if not lhs_data.stop_gradient:
            self.needs_input_grad[5] = True
            self.register_hook(lhs_data, "lhs_data")
        else:
            self.delete_hook("lhs_data")
        if not rhs_data.stop_gradient:
            self.needs_input_grad[6] = True
            self.register_hook(rhs_data, "rhs_data")
        else:
            self.delete_hook("rhs_data")
        out_data.stop_gradient = False
        lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
        rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
        feat_shape = K.infer_binary_feature_shape(binary_op, lhs_data_nd, rhs_data_nd)
        out_shape = feat_shape
        if binary_op == 'dot':
            out_shape = feat_shape[:-1]
        out_data_nd = zerocopy_to_dgl_ndarray(out_data)

        K.binary_op_reduce(
            reducer if reducer != 'mean' else 'sum',
            binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
            out_data_nd, lhs_map[0], rhs_map[0], out_map[0])
        # normalize if mean reducer
        # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
        out_data.set_value(zerocopy_from_dgl_ndarray(out_data_nd))
        if reducer == 'mean':
            degs = paddorch.zeros((out_data.shape[0],), dtype=lhs_data.dtype)  # stands in for lhs_data.new_empty((out_data.shape[0],))
            degs_nd = zerocopy_to_dgl_ndarray(degs)
            if lhs != TargetCode.DST:  # src or edge
                target = lhs
                n = lhs_data.shape[0]
                in_map = lhs_map[0]
            else:  # rhs != TargetCode.DST
                target = rhs
                n = rhs_data.shape[0]
                in_map = rhs_map[0]
            in_ones = lhs_data.new_ones((n,))
            in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
            K.copy_reduce(
                'sum', graph, target, in_ones_nd, degs_nd, in_map, out_map[0])
            # reshape degs so it broadcasts against out_data
            degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.dim() - 1)).clamp(min=1)
            out_data = out_data / degs
            # del in_ones_nd,degs_nd
        else:
            degs = None
        # save_for_backward can only save variables
        self.backward_cache = (reducer, binary_op, graph, lhs, rhs, lhs_map,
                               rhs_map, out_map, feat_shape, degs)
        self.save_for_backward(lhs_data, rhs_data, out_data)
        out_data.register_hook(self.backward)  # must be registered before the line below
        # Add zero-valued terms so the output is connected to both inputs in
        # the autograd graph; otherwise no gradient would flow back to them.
        out_data = out_data + th.mean(lhs_data) * 0 + th.mean(rhs_data) * 0
        ret_var = paddorch.convertTensor(out_data)
        ret_var.stop_gradient = False
        # del lhs_data_nd,rhs_data_nd
        return ret_var


    def backward(self, grad_out):
        reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \
        feat_shape, degs = self.backward_cache
        lhs_data, rhs_data, out_data = self.saved_tensors
        lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
        rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
        out_data_nd = zerocopy_to_dgl_ndarray(out_data)
        grad_lhs = None
        grad_rhs = None
        if reducer == 'mean':
            grad_out = grad_out / degs
        grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
        if self.needs_input_grad[5]:
            # stands in for grad_out.new_empty((lhs_data_nd.shape[0],) + feat_shape)
            grad_lhs = paddorch.zeros((lhs_data_nd.shape[0],) + feat_shape, dtype=grad_out.dtype)
            grad_lhs.stop_gradient = False
            grad_lhs_nd = zerocopy_to_dgl_ndarray(grad_lhs)
            K.backward_lhs_binary_op_reduce(
                reducer if reducer != 'mean' else 'sum',
                binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
                out_data_nd, grad_out_nd, grad_lhs_nd,
                lhs_map[1], rhs_map[1], out_map[1])
            grad_lhs.set_value(zerocopy_from_dgl_ndarray(grad_lhs_nd))
            grad_lhs = _reduce_grad(grad_lhs, lhs_data_nd.shape)

        if self.needs_input_grad[6]:
            # stands in for grad_out.new_empty((rhs_data_nd.shape[0],) + feat_shape)
            grad_rhs = paddorch.zeros((rhs_data_nd.shape[0],) + feat_shape, dtype=grad_out.dtype)
            grad_rhs.stop_gradient = False
            grad_rhs_nd = zerocopy_to_dgl_ndarray(grad_rhs)
            K.backward_rhs_binary_op_reduce(
                reducer if reducer != 'mean' else 'sum',
                binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
                out_data_nd, grad_out_nd, grad_rhs_nd,
                lhs_map[1], rhs_map[1], out_map[1])
            grad_rhs.set_value(zerocopy_from_dgl_ndarray(grad_rhs_nd))
            grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape)

        self.grad_cache["lhs_data"] = grad_lhs
        self.grad_cache["rhs_data"] = grad_rhs

        return grad_out
        # return None, None, None, None, None, paddorch.convertTensor(grad_lhs), paddorch.convertTensor(grad_rhs), None, None, None, \
        #        None, None


def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
                  out_size, lhs_map=(None, None), rhs_map=(None, None), out_map=(None, None)):
    lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
    rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
    feat_shape = K.infer_binary_feature_shape(binary_op, lhs_data_nd, rhs_data_nd)

    out_shape = feat_shape
    if binary_op == 'dot':
        out_shape = feat_shape[:-1]
    out_data = paddorch.zeros((out_size,) + out_shape, dtype=lhs_data.dtype)  # stands in for lhs_data.new_empty((out_size,) + out_shape)

    return BinaryReduce.apply(
        reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, out_data,
        out_size, lhs_map, rhs_map, out_map)


class CopyReduce(th.autograd.Function):
    # @staticmethod
    def forward(self, reducer, graph, target, in_data, out_data, out_size, in_map,
                out_map):
        self.needs_input_grad = [False] * 7
        # clean up hooks left over from a previous forward pass
        for key in self.grad_cache:
            self.delete_hook(key)

        if not in_data.stop_gradient:
            self.needs_input_grad[3] = True
            self.register_hook(in_data, "in_data")
        else:
            self.delete_hook("in_data")
        out_data.stop_gradient = False
        in_data_nd = zerocopy_to_dgl_ndarray(in_data)
        out_data_nd = zerocopy_to_dgl_ndarray(out_data)
        K.copy_reduce(
            reducer if reducer != 'mean' else 'sum',
            graph, target, in_data_nd, out_data_nd, in_map[0], out_map[0])
        out_data.set_value(zerocopy_from_dgl_ndarray(out_data_nd))

        # normalize if mean reducer
        # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
        if reducer == 'mean':
            in_ones = paddorch.ones((in_data.shape[0],), dtype=in_data.dtype)  # stands in for in_data.new_ones((in_data.shape[0],))
            degs = paddorch.zeros((out_data.shape[0],), dtype=in_data.dtype)  # stands in for in_data.new_empty((out_data.shape[0],))
            in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
            degs_nd = zerocopy_to_dgl_ndarray(degs)
            K.copy_reduce(
                'sum', graph, target, in_ones_nd, degs_nd, in_map[0], out_map[0])
            # reshape
            degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.dim() - 1)).clamp(min=1)
            out_data = out_data / degs
            # del in_ones_nd,degs_nd
        else:
            degs = None
        # save_for_backward can only save variables
        self.backward_cache = (reducer, graph, target, in_map, out_map, degs)
        self.save_for_backward(in_data, out_data)
        out_data.register_hook(self.backward)  # must be registered before the line below
        # Add a zero-valued term so the output is connected to in_data in the
        # autograd graph; otherwise no gradient would flow back to it.
        out_data = out_data + th.sum(in_data) * 0
        ret_var = paddorch.convertTensor(out_data)
        ret_var.stop_gradient = False
        # del out_data_nd,in_data_nd
        return ret_var


    def backward(self, grad_out):
        reducer, graph, target, in_map, out_map, degs = self.backward_cache
        in_data, out_data = self.saved_tensors
        in_data_nd = zerocopy_to_dgl_ndarray(in_data)
        out_data_nd = zerocopy_to_dgl_ndarray(out_data)
        grad_in = None
        if reducer == 'mean':
            grad_out = grad_out / degs
        grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
        if self.needs_input_grad[3]:
            # stands in for grad_out.new_empty(in_data_nd.shape)
            grad_in = paddorch.zeros(in_data_nd.shape, dtype=grad_out.dtype)
            grad_in_nd = zerocopy_to_dgl_ndarray(grad_in)
            K.backward_copy_reduce(
                reducer if reducer != 'mean' else 'sum',
                graph, target, in_data_nd, out_data_nd, grad_out_nd,
                grad_in_nd, in_map[1], out_map[1])
            grad_in.set_value(zerocopy_from_dgl_ndarray(grad_in_nd))

        self.grad_cache["in_data"] = grad_in
        return grad_out


def copy_reduce(reducer, graph, target, in_data, out_size, in_map=(None, None),
                out_map=(None, None)):
    out_data = paddorch.zeros((out_size,) + tuple(in_data.shape[1:]), dtype=in_data.dtype)  # stands in for in_data.new_empty((out_size,) + tuple(in_data.shape[1:]))
    return CopyReduce.apply(reducer, graph, target, in_data, out_data, out_size, in_map, out_map)


def _reduce_grad(grad, shape):
    """Reduce gradient on the broadcast dimension

    If there was broadcasting in the forward pass, the gradient has to be
    summed back along the broadcast dimensions. This function compares the
    input tensor shape with the gradient shape and performs that reduction.

    Parameters
    ----------
    grad: Tensor
        Gradient tensor
    shape: tuple
        Shape of input tensor

    Returns
    -------
    Tensor
    """
    grad_shape = grad.shape[1:]
    in_shape = shape[1:]
    if in_shape == tuple(grad_shape):
        # no need to reduce
        return grad
    num_to_squeeze = len(grad_shape) - len(in_shape)
    # pad inshape
    in_shape = (1,) * num_to_squeeze + in_shape
    # locate the broadcast axes (offset by 1 to skip the leading row axis) and
    # sum the gradient over all of them, not only the first one
    reduce_idx = th.nonzero(th.tensor(grad_shape) - th.tensor(in_shape)) + 1
    reduce_dims = reduce_idx.view(-1).numpy().tolist()
    grad = grad.sum(dim=reduce_dims, keepdim=True)
    return grad.view(shape)
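# Sketch with illustrative shapes: if the forward pass broadcast an input of
# shape (5, 1, 4) against a gradient of shape (5, 3, 4), axis 1 is summed back:
#
#     g = th.ones((5, 3, 4))
#     _reduce_grad(g, (5, 1, 4)).shape    # -> (5, 1, 4)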


def sync():
    # Paddle's imperative mode executes ops synchronously from the Python
    # side, so no explicit synchronization is needed.
    pass
